from collections import defaultdict
from functools import lru_cache
from typing import Tuple, Dict, Any, List

import numpy as np

from centralized_verification.shields.shield import Shield, T, ShieldResult, S, AgentResult, AgentUpdate
from centralized_verification.shields.slugs_shielding.combine_identical_states import ShieldSpec
from centralized_verification.shields.slugs_shielding.label_extractor import LabelExtractor


class SlugsCentralizedShield(Shield[T, Tuple[int, Any]]):
    """Centralized shield that follows a precomputed ``ShieldSpec`` automaton.

    The shield state handed back to the caller is either the sentinel ``-1``
    (no step has been evaluated yet) or a tuple
    ``(automaton_state_num, joint_action_taken)`` describing the automaton
    state the previous step was evaluated in and the joint action that was
    finally allowed from it.
    """

    def __init__(self, env: T, shield_spec: ShieldSpec, label_extractor: LabelExtractor, **kwargs):
        """Store the automaton and the callable that maps env states to labels.

        :param env: Environment, forwarded to the base ``Shield``.
        :param shield_spec: Mapping from automaton state number to a state
            object exposing ``label``, ``initial_state`` and ``actions``.
        :param label_extractor: Callable applied to an environment state to
            obtain the label compared against ``shield_spec[...].label``.
        :param kwargs: Forwarded unchanged to the base ``Shield``.
        """
        super().__init__(env, **kwargs)
        self.shield_spec = shield_spec
        self.label_extractor = label_extractor
        # Lazily built by _get_label_to_initial_shield_state_dict; kept on the
        # instance so the cache lives and dies with this shield.
        self._initial_states_by_label = None

    def get_actual_state_from_candidates(self, possible_states, label):
        """Return the first candidate automaton state whose label equals ``label``.

        :param possible_states: Iterable of automaton state numbers.
        :param label: Label extracted from the current environment state.
        :raises Exception: If no candidate carries the given label.
        """
        for state_num in possible_states:
            if self.shield_spec[state_num].label == label:
                return state_num

        raise Exception("The shield did not accurately model the environment")

    def get_safe_action_set_and_possible_states(self, state):
        """Return the ``actions`` mapping of automaton state ``state``.

        As used elsewhere in this class, the mapping goes from an allowed
        joint action to the candidate successor automaton states.
        """
        return self.shield_spec[state].actions

    def _get_label_to_initial_shield_state_dict(self) -> Dict[Any, List[int]]:
        """Group the automaton's initial states by label (computed once per instance).

        :return: Mapping ``label -> [state_num, ...]`` covering every state
            whose ``initial_state`` flag is truthy.
        """
        # getattr keeps this working for instances created without __init__.
        cached = getattr(self, "_initial_states_by_label", None)
        if cached is None:
            cached = defaultdict(list)
            for state_num, state in self.shield_spec.items():
                if state.initial_state:
                    cached[state.label].append(state_num)
            self._initial_states_by_label = cached

        return cached

    def get_initial_shield_state(self, state, initial_joint_obs) -> S:
        """Return the sentinel ``-1``.

        ``evaluate_joint_action`` resolves the real initial automaton state
        from the label of the first environment state it sees.
        """
        return -1

    def evaluate_joint_action(self, state, joint_obs, proposed_action, shield_state: T) -> Tuple[ShieldResult, S]:
        """Check a proposed joint action against the shield and repair it if unsafe.

        Repair strategy, in order:

        1. Accept the proposed joint action if the automaton allows it.
        2. Otherwise, visiting agents in random order, force one agent after
           another to action ``0`` until the joint action becomes allowed.
        3. Otherwise fall back to an arbitrary allowed joint action.

        :param state: Environment state; passed to ``label_extractor``.
        :param joint_obs: Not used by this implementation.
        :param proposed_action: Sequence with one action per agent.
        :param shield_state: ``-1`` on the first step, otherwise the
            ``(automaton_state_num, joint_action)`` tuple returned by the
            previous call.
        :return: ``(per_agent_results, next_shield_state)``.
        :raises Exception: If the initial automaton state is not uniquely
            determined by the label, or if no successor matches the label.
        :raises AssertionError: If the current automaton state allows no
            joint action at all.
        """
        label = self.label_extractor(state)
        # Normalise so set membership works even if the caller passes a list.
        proposed_action = tuple(proposed_action)

        if shield_state == -1:
            # .get avoids inserting empty entries into the cached defaultdict.
            possible_shield_states = self._get_label_to_initial_shield_state_dict().get(label, [])
            if len(possible_shield_states) != 1:
                raise Exception(f"Shield has {len(possible_shield_states)} possible start states for the given "
                                "label (expected exactly one). "
                                "Override get_initial_shield_state to specify the desired behavior")

            current_shield_automaton_state_num = possible_shield_states[0]
        else:
            prev_shield_aut_state, prev_action = shield_state
            possible_shield_states = self.shield_spec[prev_shield_aut_state].actions[prev_action]
            current_shield_automaton_state_num = self.get_actual_state_from_candidates(possible_shield_states, label)

        allowed_actions = set(self.shield_spec[current_shield_automaton_state_num].actions.keys())

        shield_result = [AgentResult(AgentUpdate(action=action)) for action in proposed_action]

        if proposed_action in allowed_actions:
            return shield_result, (current_shield_automaton_state_num, proposed_action)

        # noinspection PyTypeChecker
        priority: List[int] = np.random.permutation(np.arange(len(proposed_action))).tolist()  # Agents in random order

        # Set individual agents to the default action (0) until the joint action is safe
        proposed_action_list = list(proposed_action)
        for agent_to_neuter in priority:
            if proposed_action_list[agent_to_neuter] == 0:
                # Already the default action: nothing is replaced, and the
                # joint action is unchanged since the last (failed) check.
                continue

            proposed_action_list[agent_to_neuter] = 0
            shield_result[agent_to_neuter] = self.replace_action_agent_result(proposed_action[agent_to_neuter], 0)

            neutered_action = tuple(proposed_action_list)
            if neutered_action in allowed_actions:
                return shield_result, (current_shield_automaton_state_num, neutered_action)

        # Just pick _some_ safe action
        if allowed_actions:
            taken_action = next(iter(allowed_actions))

            # Rebuild the results from scratch: only agents whose action
            # actually differs from their proposal are reported as replaced.
            shield_result = [
                AgentResult(AgentUpdate(action=taken_indiv_action))
                if taken_indiv_action == proposed_indiv_action
                else self.replace_action_agent_result(proposed_indiv_action, taken_indiv_action)
                for proposed_indiv_action, taken_indiv_action in zip(proposed_action, taken_action)
            ]
            return shield_result, (current_shield_automaton_state_num, taken_action)

        # Raised explicitly (not `assert False`) so the check survives `python -O`.
        raise AssertionError("There aren't any safe actions, the input wasn't a good shield")
